import torch

seq_q = torch.tensor([[6, 1, 2, 3, 5],
                      [6, 1, 2, 3, 4]])
seq_k = torch.tensor([[1, 2, 3, 5, 0],
                      [1, 2, 3, 4, 0]])

# seq_q: [batch_size, len_q] = (2, 5)
# seq_k: [batch_size, len_k] = (2, 5)
self_attn_mask = get_attn_pad_mask(seq_q, seq_k)
# self_attn_mask: [batch_size, len_q, len_k] = (2, 5, 5)
print(self_attn_mask)
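# With the get_attn_pad_mask sketch above, every query row masks the same key
# positions, so each (5, 5) slice is True only in its last column (the
# position where seq_k holds the padding id 0).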

n_heads = 8
attn_mask = self_attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
# attn_mask: [batch_size, n_heads, len_q, len_k] = (2, 8, 5, 5)
print(attn_mask)
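# Hedged usage sketch (the names Q, K, d_k below are illustrative, not from
# the original post): the per-head mask is applied to the raw attention
# scores before softmax, so padded key positions get (near-)zero weight.
d_k = 64
Q = torch.randn(2, n_heads, 5, d_k)                          # [batch_size, n_heads, len_q, d_k]
K = torch.randn(2, n_heads, 5, d_k)                          # [batch_size, n_heads, len_k, d_k]
scores = torch.matmul(Q, K.transpose(-1, -2)) / d_k ** 0.5   # [batch_size, n_heads, len_q, len_k]
scores = scores.masked_fill(attn_mask, -1e9)                 # suppress padded key positions
attn_weights = torch.softmax(scores, dim=-1)
print(attn_weights[0, 0])                                    # last column (the pad position) is ~0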
